{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "071c7845",
   "metadata": {},
   "source": [
    "下面介绍基于循环神经网络的编码器和解码器的代码实现。首先是作为编码器的循环神经网络。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f682c7d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Code adapted from the GitHub project pytorch/tutorials\n",
    "(Copyright (c) 2023, PyTorch, BSD-3-Clause License (see appendix))\n",
    "\"\"\"\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "# Imported here (first code cell) because later decoder cells use\n",
    "# F.log_softmax/F.relu; keeps imports at the top of the notebook.\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class RNNEncoder(nn.Module):\n",
    "    \"\"\"GRU-based sequence encoder.\n",
    "\n",
    "    Embeds token ids and runs a single-layer GRU; returns the\n",
    "    per-step outputs and the final hidden state.\n",
    "    \"\"\"\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(RNNEncoder, self).__init__()\n",
    "        # Hidden-state size (also used as the embedding size)\n",
    "        self.hidden_size = hidden_size\n",
    "        # Vocabulary size\n",
    "        self.vocab_size = vocab_size\n",
    "        # Token embedding layer\n",
    "        self.embedding = nn.Embedding(self.vocab_size,\\\n",
    "            self.hidden_size)\n",
    "        self.gru = nn.GRU(self.hidden_size, self.hidden_size,\\\n",
    "            batch_first=True)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        # inputs: batch * seq_len (token ids); the GRU was built with\n",
    "        # batch_first=True, so a leading batch dimension is required\n",
    "        features = self.embedding(inputs)\n",
    "        # output: batch * seq_len * hidden_size;\n",
    "        # hidden: 1 * batch * hidden_size (final state)\n",
    "        output, hidden = self.gru(features)\n",
    "        return output, hidden"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0dc8af74",
   "metadata": {},
   "source": [
    "接下来是作为解码器的另一个循环神经网络的代码实现。<!--我们使用编码器最终的输出用作解码器的初始隐状态，这个输出向量有时称作上下文向量（context vector），它编码了整个源序列的信息。解码器最初的输入词元是“\\<sos\\>”（start-of-string）。解码器的目标为，输入编码器隐状态，一步一步解码出整个目标序列。解码器每一步的输入可以是真实目标序列，也可以是解码器上一步的预测结果。-->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9beee561",
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNDecoder(nn.Module):\n",
    "    # GRU decoder without attention; the encoder's final hidden state\n",
    "    # is used as the initial decoder hidden state.\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(RNNDecoder, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.vocab_size = vocab_size\n",
    "        # Seq2seq does not require the encoder and decoder to share a\n",
    "        # language, so the decoder needs its own embedding layer\n",
    "        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)\n",
    "        self.gru = nn.GRU(self.hidden_size, self.hidden_size,\\\n",
    "            batch_first=True)\n",
    "        # Maps the output hidden state to a distribution over the vocabulary\n",
    "        self.out = nn.Linear(self.hidden_size, self.vocab_size)\n",
    "\n",
    "    # Decode a whole sequence\n",
    "    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
    "        batch_size = encoder_outputs.size(0)\n",
    "        # Start decoding from <sos>\n",
    "        # NOTE(review): SOS_token and MAX_LENGTH are module-level globals\n",
    "        # defined in later cells; run those cells before calling forward()\n",
    "        decoder_input = torch.empty(batch_size, 1,\\\n",
    "            dtype=torch.long).fill_(SOS_token)\n",
    "        decoder_hidden = encoder_hidden\n",
    "        decoder_outputs = []\n",
    "        \n",
    "        # If a target sequence is given, its length fixes the number of\n",
    "        # decoding steps; otherwise decode up to the maximum length\n",
    "        if target_tensor is not None:\n",
    "            seq_length = target_tensor.size(1)\n",
    "        else:\n",
    "            seq_length = MAX_LENGTH\n",
    "        \n",
    "        # Run seq_length decoding steps\n",
    "        for i in range(seq_length):\n",
    "            # Each step consumes one token and one hidden state\n",
    "            decoder_output, decoder_hidden = self.forward_step(\\\n",
    "                decoder_input, decoder_hidden)\n",
    "            decoder_outputs.append(decoder_output)\n",
    "\n",
    "            if target_tensor is not None:\n",
    "                # Teacher forcing: feed the ground-truth token as the\n",
    "                # next input\n",
    "                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
    "            else:\n",
    "                # Greedy decoding: take the most probable token from the\n",
    "                # current step's output distribution as the next input\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                # detach() cuts the sampled token out of the graph so no\n",
    "                # gradient flows back through it\n",
    "                decoder_input = topi.squeeze(-1).detach()\n",
    "\n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        # Return None as the third value to keep the same interface as\n",
    "        # AttnRNNDecoder (which returns attention weights there)\n",
    "        return decoder_outputs, decoder_hidden, None\n",
    "\n",
    "    # Decode a single step: one token + one hidden state in,\n",
    "    # one logits tensor + updated hidden state out\n",
    "    def forward_step(self, input, hidden):\n",
    "        output = self.embedding(input)\n",
    "        output = F.relu(output)\n",
    "        output, hidden = self.gru(output, hidden)\n",
    "        output = self.out(output)\n",
    "        return output, hidden\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38306ed9",
   "metadata": {},
   "source": [
    "下面介绍基于注意力机制的循环神经网络解码器的代码实现。\n",
    "我们使用一个注意力层来计算注意力权重，其输入为解码器的输入和隐状态。\n",
    "这里使用Bahdanau注意力（Bahdanau attention），这是序列到序列模型中应用最广泛的注意力机制，特别是机器翻译任务。该注意力机制使用一个对齐模型（alignment model）来计算编码器和解码器隐状态之间的注意力分数，具体来讲就是一个前馈神经网络。相比于点乘注意力，Bahdanau注意力利用了非线性变换。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d675aa6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Code adapted from the GitHub project pytorch/tutorials\n",
    "(Copyright (c) 2023, PyTorch, BSD-3-Clause License (see appendix))\n",
    "\"\"\"\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class BahdanauAttention(nn.Module):\n",
    "    # Additive (Bahdanau) attention: score = Va^T tanh(Wa q + Ua k)\n",
    "    def __init__(self, hidden_size):\n",
    "        super(BahdanauAttention, self).__init__()\n",
    "        self.Wa = nn.Linear(hidden_size, hidden_size)\n",
    "        self.Ua = nn.Linear(hidden_size, hidden_size)\n",
    "        self.Va = nn.Linear(hidden_size, 1)\n",
    "\n",
    "    def forward(self, query, keys):\n",
    "        # query: batch * 1 * hidden_size (current decoder state)\n",
    "        # keys: batch * seq_length * hidden_size (encoder outputs)\n",
    "        # Broadcasting expands the length-1 query axis over seq_length\n",
    "        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))\n",
    "        # batch * seq_length * 1 -> batch * 1 * seq_length\n",
    "        scores = scores.squeeze(2).unsqueeze(1)\n",
    "\n",
    "        weights = F.softmax(scores, dim=-1)\n",
    "        # context: batch * 1 * hidden_size, weighted sum of keys\n",
    "        context = torch.bmm(weights, keys)\n",
    "        return context, weights\n",
    "\n",
    "class AttnRNNDecoder(nn.Module):\n",
    "    \"\"\"GRU decoder with Bahdanau attention over the encoder outputs.\"\"\"\n",
    "    def __init__(self, vocab_size, hidden_size):\n",
    "        super(AttnRNNDecoder, self).__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.vocab_size = vocab_size\n",
    "        self.embedding = nn.Embedding(self.vocab_size, self.hidden_size)\n",
    "        self.attention = BahdanauAttention(hidden_size)\n",
    "        # The GRU consumes the embedded token concatenated with the\n",
    "        # context vector, hence input size 2 * hidden_size\n",
    "        self.gru = nn.GRU(2 * self.hidden_size, self.hidden_size,\\\n",
    "            batch_first=True)\n",
    "        # Maps the attended hidden state to a vocabulary distribution\n",
    "        self.out = nn.Linear(self.hidden_size, self.vocab_size)\n",
    "\n",
    "    # Decode a whole sequence\n",
    "    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
    "        batch_size = encoder_outputs.size(0)\n",
    "        # Start decoding from <sos> (SOS_token is a module-level global)\n",
    "        decoder_input = torch.empty(batch_size, 1, dtype=\\\n",
    "            torch.long).fill_(SOS_token)\n",
    "        decoder_hidden = encoder_hidden\n",
    "        decoder_outputs = []\n",
    "        attentions = []\n",
    "\n",
    "        # If a target sequence is given, its length fixes the number of\n",
    "        # decoding steps; otherwise decode up to the maximum length\n",
    "        if target_tensor is not None:\n",
    "            seq_length = target_tensor.size(1)\n",
    "        else:\n",
    "            seq_length = MAX_LENGTH\n",
    "\n",
    "        # Run seq_length decoding steps\n",
    "        for i in range(seq_length):\n",
    "            # Each step consumes one token and one hidden state\n",
    "            decoder_output, decoder_hidden, attn_weights = \\\n",
    "                self.forward_step(\n",
    "                    decoder_input, decoder_hidden, encoder_outputs\n",
    "            )\n",
    "            decoder_outputs.append(decoder_output)\n",
    "            attentions.append(attn_weights)\n",
    "\n",
    "            if target_tensor is not None:\n",
    "                # Teacher forcing: feed the ground-truth token as the\n",
    "                # next input\n",
    "                decoder_input = target_tensor[:, i].unsqueeze(1)\n",
    "            else:\n",
    "                # Greedy decoding: take the most probable token from the\n",
    "                # current step's output distribution as the next input\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                # detach() keeps gradients from flowing back through the\n",
    "                # sampled input\n",
    "                decoder_input = topi.squeeze(-1).detach()\n",
    "\n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        attentions = torch.cat(attentions, dim=1)\n",
    "        # Same interface as RNNDecoder, but the third value carries the\n",
    "        # attention weights instead of None\n",
    "        return decoder_outputs, decoder_hidden, attentions\n",
    "\n",
    "    # Decode a single step\n",
    "    def forward_step(self, input, hidden, encoder_outputs):\n",
    "        embedded = self.embedding(input)\n",
    "        # The GRU hidden state is 1 * batch * hidden_size, while the\n",
    "        # attention query must be batch * 1 * hidden_size\n",
    "        query = hidden.permute(1, 0, 2)\n",
    "        context, attn_weights = self.attention(query, encoder_outputs)\n",
    "        input_gru = torch.cat((embedded, context), dim=2)\n",
    "        # The input hidden state must be 1 * batch * hidden_size\n",
    "        output, hidden = self.gru(input_gru, hidden)\n",
    "        output = self.out(output)\n",
    "        return output, hidden, attn_weights\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d63b031",
   "metadata": {},
   "source": [
    "\n",
    "接下来我们实现基于Transformer的编码器和解码器。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c9a91cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.append('./code')\n",
    "# NOTE(review): wildcard import from the local transformer module;\n",
    "# presumably provides EmbeddingLayer, TransformerLayer and the classes\n",
    "# used in the next cell — explicit imports would be clearer\n",
    "from transformer import *\n",
    "\n",
    "class TransformerEncoder(nn.Module):\n",
    "    # Single-layer Transformer encoder built from TransformerLayer\n",
    "    def __init__(self, vocab_size, max_len, hidden_size, num_heads,\\\n",
    "            dropout, intermediate_size):\n",
    "        super().__init__()\n",
    "        self.embedding_layer = EmbeddingLayer(vocab_size, max_len,\\\n",
    "            hidden_size)\n",
    "        # Use TransformerLayer directly as the encoder layer; for\n",
    "        # simplicity only one layer is used\n",
    "        self.layer = TransformerLayer(hidden_size, num_heads,\\\n",
    "            dropout, intermediate_size)\n",
    "        # Unlike TransformerLM, the encoder needs no output linear layer\n",
    "        \n",
    "    def forward(self, input_ids):\n",
    "        # This forward() handles one sentence at a time; supporting\n",
    "        # batches would require returning hidden states according to\n",
    "        # the lengths of the input sequences\n",
    "        assert input_ids.ndim == 2 and input_ids.size(0) == 1\n",
    "        seq_len = input_ids.size(1)\n",
    "        assert seq_len <= self.embedding_layer.max_len\n",
    "        \n",
    "        # 1 * seq_len position ids\n",
    "        pos_ids = torch.unsqueeze(torch.arange(seq_len), dim=0)\n",
    "        # All positions are valid (no padding for a single sentence)\n",
    "        attention_mask = torch.ones((1, seq_len), dtype=torch.int32)\n",
    "        input_states = self.embedding_layer(input_ids, pos_ids)\n",
    "        hidden_states = self.layer(input_states, attention_mask)\n",
    "        return hidden_states, attention_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5e8cfc9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiHeadCrossAttention(MultiHeadSelfAttention):\n",
    "    # Reuses the projections and head reshaping of MultiHeadSelfAttention;\n",
    "    # queries come from the target side, keys/values from the source side\n",
    "    def forward(self, tgt, tgt_mask, src, src_mask):\n",
    "        \"\"\"\n",
    "        tgt: query, batch_size * tgt_seq_len * hidden_size\n",
    "        tgt_mask: batch_size * tgt_seq_len\n",
    "        src: keys/values, batch_size * src_seq_len * hidden_size\n",
    "        src_mask: batch_size * src_seq_len\n",
    "        \"\"\"\n",
    "        # (batch_size * num_heads) * seq_len * (hidden_size / num_heads)\n",
    "        queries = self.transpose_qkv(self.W_q(tgt))\n",
    "        keys = self.transpose_qkv(self.W_k(src))\n",
    "        values = self.transpose_qkv(self.W_v(src))\n",
    "        # Unlike self-attention, build a cross mask between target and\n",
    "        # source positions\n",
    "        # batch_size * tgt_seq_len * src_seq_len\n",
    "        attention_mask = tgt_mask.unsqueeze(2) * src_mask.unsqueeze(1)\n",
    "        # Repeat the mask so every attention head gets a copy\n",
    "        # (batch_size * num_heads) * tgt_seq_len * src_seq_len\n",
    "        attention_mask = torch.repeat_interleave(attention_mask,\\\n",
    "            repeats=self.num_heads, dim=0)\n",
    "        # (batch_size * num_heads) * tgt_seq_len * \\\n",
    "        # (hidden_size / num_heads)\n",
    "        output = self.attention(queries, keys, values, attention_mask)\n",
    "        # batch * tgt_seq_len * hidden_size\n",
    "        output_concat = self.transpose_output(output)\n",
    "        return self.W_o(output_concat)\n",
    "\n",
    "# TransformerDecoderLayer adds cross multi-head attention on top of\n",
    "# what TransformerLayer provides\n",
    "class TransformerDecoderLayer(nn.Module):\n",
    "    def __init__(self, hidden_size, num_heads, dropout,\\\n",
    "                 intermediate_size):\n",
    "        super().__init__()\n",
    "        self.self_attention = MultiHeadSelfAttention(hidden_size,\\\n",
    "            num_heads, dropout)\n",
    "        self.add_norm1 = AddNorm(hidden_size, dropout)\n",
    "        self.enc_attention = MultiHeadCrossAttention(hidden_size,\\\n",
    "            num_heads, dropout)\n",
    "        self.add_norm2 = AddNorm(hidden_size, dropout)\n",
    "        self.fnn = PositionWiseFNN(hidden_size, intermediate_size)\n",
    "        self.add_norm3 = AddNorm(hidden_size, dropout)\n",
    "\n",
    "    def forward(self, src_states, src_mask, tgt_states, tgt_mask):\n",
    "        # Masked multi-head self-attention over the target states\n",
    "        tgt = self.add_norm1(tgt_states, self.self_attention(\\\n",
    "            tgt_states, tgt_states, tgt_states, tgt_mask))\n",
    "        # Cross multi-head attention over the source states\n",
    "        tgt = self.add_norm2(tgt, self.enc_attention(tgt,\\\n",
    "            tgt_mask, src_states, src_mask))\n",
    "        # Position-wise feed-forward network\n",
    "        return self.add_norm3(tgt, self.fnn(tgt))\n",
    "\n",
    "class TransformerDecoder(nn.Module):\n",
    "    # Single-layer Transformer decoder with greedy/teacher-forced decoding\n",
    "    def __init__(self, vocab_size, max_len, hidden_size, num_heads,\\\n",
    "                 dropout, intermediate_size):\n",
    "        super().__init__()\n",
    "        self.embedding_layer = EmbeddingLayer(vocab_size, max_len,\\\n",
    "            hidden_size)\n",
    "        # For simplicity only one decoder layer is used\n",
    "        self.layer = TransformerDecoderLayer(hidden_size, num_heads,\\\n",
    "            dropout, intermediate_size)\n",
    "        # Like TransformerLM, the decoder needs an output layer\n",
    "        self.output_layer = nn.Linear(hidden_size, vocab_size)\n",
    "        \n",
    "    def forward(self, src_states, src_mask, tgt_tensor=None):\n",
    "        # Only one sentence at a time: shape 1 * seq_len * hidden_size\n",
    "        assert src_states.ndim == 3 and src_states.size(0) == 1\n",
    "        \n",
    "        if tgt_tensor is not None:\n",
    "            # Only one sentence at a time: shape 1 * seq_len\n",
    "            assert tgt_tensor.ndim == 2 and tgt_tensor.size(0) == 1\n",
    "            seq_len = tgt_tensor.size(1)\n",
    "            assert seq_len <= self.embedding_layer.max_len\n",
    "        else:\n",
    "            seq_len = self.embedding_layer.max_len\n",
    "        \n",
    "        # Start decoding from <sos> (SOS_token is a module-level global)\n",
    "        decoder_input = torch.empty(1, 1, dtype=torch.long).\\\n",
    "            fill_(SOS_token)\n",
    "        decoder_outputs = []\n",
    "        \n",
    "        for i in range(seq_len):\n",
    "            decoder_output = self.forward_step(decoder_input,\\\n",
    "                src_mask, src_states)\n",
    "            decoder_outputs.append(decoder_output)\n",
    "            \n",
    "            if tgt_tensor is not None:\n",
    "                # Teacher forcing: append the ground-truth token to the\n",
    "                # growing decoder input\n",
    "                decoder_input = torch.cat((decoder_input,\\\n",
    "                    tgt_tensor[:, i:i+1]), 1)\n",
    "            else:\n",
    "                # Greedy decoding: append the most probable token from\n",
    "                # the current step's output distribution\n",
    "                _, topi = decoder_output.topk(1)\n",
    "                # detach() keeps gradients from flowing back through\n",
    "                # the sampled input\n",
    "                decoder_input = torch.cat((decoder_input,\\\n",
    "                    topi.squeeze(-1).detach()), 1)\n",
    "                \n",
    "        decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
    "        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
    "        # Match the RNNDecoder interface (hidden state and attention\n",
    "        # slots are returned as None)\n",
    "        return decoder_outputs, None, None\n",
    "        \n",
    "    # Decode one step. The interface differs slightly from RNNDecoder:\n",
    "    # RNNDecoder takes one hidden state and one token and returns a\n",
    "    # distribution plus a hidden state; TransformerDecoder takes no\n",
    "    # hidden state, consumes the whole target-side history, and\n",
    "    # returns only the distribution (no hidden state)\n",
    "    def forward_step(self, tgt_inputs, src_mask, src_states):\n",
    "        seq_len = tgt_inputs.size(1)\n",
    "        # 1 * seq_len position ids\n",
    "        pos_ids = torch.unsqueeze(torch.arange(seq_len), dim=0)\n",
    "        tgt_mask = torch.ones((1, seq_len), dtype=torch.int32)\n",
    "        tgt_states = self.embedding_layer(tgt_inputs, pos_ids)\n",
    "        hidden_states = self.layer(src_states, src_mask, tgt_states,\\\n",
    "            tgt_mask)\n",
    "        # Only the last position's logits are needed for the next token\n",
    "        output = self.output_layer(hidden_states[:, -1:, :])\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e58b811",
   "metadata": {},
   "source": [
    "下面以机器翻译（中-英）为例展示如何训练序列到序列模型。这里使用的是中英文Books数据，其中中文标题来源于第4章所使用的数据集，英文标题是使用已训练好的机器翻译模型从中文标题翻译而得，因此该数据并不保证准确性，仅用于演示。\n",
    "\n",
    "首先需要对源语言和目标语言分别建立索引，并记录词频。\n",
    "\n",
    "<!-- 代码来自：\n",
    "\n",
    "https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html <span style=\"color:blue;font-size:20px\">BSD-3-Clause license</span>\n",
    "\n",
    "下载链接：\n",
    "\n",
    "https://github.com/zwhe99/LLM-MT-Eval/blob/main/data/raw/wmt22.zh-en.en\n",
    "\n",
    "https://github.com/zwhe99/LLM-MT-Eval/blob/main/data/raw/wmt22.zh-en.zh\n",
    " -->"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6b258fc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Code adapted from the GitHub project pytorch/tutorials\n",
    "(Copyright (c) 2023, PyTorch, BSD-3-Clause License (see appendix))\n",
    "\"\"\"\n",
    "# Special token ids: decoding starts at SOS and stops at EOS\n",
    "SOS_token = 0\n",
    "EOS_token = 1\n",
    "\n",
    "class Lang:\n",
    "    # Per-language vocabulary: word<->index mappings plus word counts\n",
    "    def __init__(self, name):\n",
    "        self.name = name\n",
    "        self.word2index = {}\n",
    "        self.word2count = {}\n",
    "        self.index2word = {0: \"<sos>\", 1: \"<eos>\"}\n",
    "        self.n_words = 2  # counts <sos> and <eos>\n",
    "\n",
    "    def addSentence(self, sentence):\n",
    "        # Sentences are pre-tokenized with single spaces\n",
    "        for word in sentence.split(' '):\n",
    "            self.addWord(word)\n",
    "\n",
    "    def addWord(self, word):\n",
    "        if word not in self.word2index:\n",
    "            # First occurrence: assign the next free index\n",
    "            self.word2index[word] = self.n_words\n",
    "            self.word2count[word] = 1\n",
    "            self.index2word[self.n_words] = word\n",
    "            self.n_words += 1\n",
    "        else:\n",
    "            self.word2count[word] += 1\n",
    "            \n",
    "    def sent2ids(self, sent):\n",
    "        # NOTE(review): raises KeyError for out-of-vocabulary words\n",
    "        return [self.word2index[word] for word in sent.split(' ')]\n",
    "    \n",
    "    def ids2sent(self, ids):\n",
    "        return ' '.join([self.index2word[idx] for idx in ids])\n",
    "\n",
    "import unicodedata\n",
    "import string\n",
    "import re\n",
    "import random\n",
    "\n",
    "# The files are unicode-encoded; convert to ASCII, lowercase, and\n",
    "# adjust punctuation\n",
    "def unicodeToAscii(s):\n",
    "    # NFD-decompose, then drop combining marks (category 'Mn'),\n",
    "    # i.e. strip accents from Latin characters\n",
    "    return ''.join(\n",
    "        c for c in unicodedata.normalize('NFD', s)\n",
    "        if unicodedata.category(c) != 'Mn'\n",
    "    )\n",
    "\n",
    "def normalizeString(s):\n",
    "    s = unicodeToAscii(s.lower().strip())\n",
    "    # Insert a space before punctuation so it tokenizes separately\n",
    "    s = re.sub(r\"([,.!?])\", r\" \\1\", s)\n",
    "    return s.strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "01ae1e7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the two parallel files; the same line in each file holds one\n",
    "# source/target sentence pair\n",
    "def readLangs(lang1, lang2):\n",
    "    # Read each file and split into sentences (one per line); use\n",
    "    # context managers so the file handles are closed promptly\n",
    "    with open(f'{lang1}.txt', encoding='utf-8') as f1:\n",
    "        lines1 = f1.read().strip().split('\n')\n",
    "    with open(f'{lang2}.txt', encoding='utf-8') as f2:\n",
    "        lines2 = f2.read().strip().split('\n')\n",
    "    print(len(lines1), len(lines2))\n",
    "    \n",
    "    # Normalize\n",
    "    lines1 = [normalizeString(s) for s in lines1]\n",
    "    lines2 = [normalizeString(s) for s in lines2]\n",
    "    # Chinese is unsegmented: split into space-separated characters\n",
    "    if lang1 == 'zh':\n",
    "        lines1 = [' '.join(list(s.replace(' ', ''))) for s in lines1]\n",
    "    if lang2 == 'zh':\n",
    "        lines2 = [' '.join(list(s.replace(' ', ''))) for s in lines2]\n",
    "    pairs = [[l1, l2] for l1, l2 in zip(lines1, lines2)]\n",
    "\n",
    "    input_lang = Lang(lang1)\n",
    "    output_lang = Lang(lang2)\n",
    "    return input_lang, output_lang, pairs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "d6a561e0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2157 2157\n",
      "读取 2157 对序列\n",
      "过滤后剩余 2003 对序列\n",
      "统计词数\n",
      "zh 1368\n",
      "en 3287\n",
      "['黑 白 摄 影 教 程', 'black & white photography academy']\n"
     ]
    }
   ],
   "source": [
    "# To speed up training, filter out overly long sentences\n",
    "MAX_LENGTH = 30\n",
    "\n",
    "def filterPair(p):\n",
    "    # Keep a pair only if both sides are shorter than MAX_LENGTH tokens\n",
    "    return len(p[0].split(' ')) < MAX_LENGTH and \\\n",
    "        len(p[1].split(' ')) < MAX_LENGTH\n",
    "\n",
    "def filterPairs(pairs):\n",
    "    return [pair for pair in pairs if filterPair(pair)]\n",
    "\n",
    "def prepareData(lang1, lang2):\n",
    "    # Read, filter, and index the parallel corpus\n",
    "    input_lang, output_lang, pairs = readLangs(lang1, lang2)\n",
    "    print(f\"读取 {len(pairs)} 对序列\")\n",
    "    pairs = filterPairs(pairs)\n",
    "    print(f\"过滤后剩余 {len(pairs)} 对序列\")\n",
    "    print(\"统计词数\")\n",
    "    # Build the vocabularies from the filtered pairs only\n",
    "    for pair in pairs:\n",
    "        input_lang.addSentence(pair[0])\n",
    "        output_lang.addSentence(pair[1])\n",
    "    print(input_lang.name, input_lang.n_words)\n",
    "    print(output_lang.name, output_lang.n_words)\n",
    "    return input_lang, output_lang, pairs\n",
    "\n",
    "input_lang, output_lang, pairs = prepareData('zh', 'en')\n",
    "print(random.choice(pairs))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cfa6874b",
   "metadata": {},
   "source": [
    "为了便于训练，对每一对源-目标句子需要准备一个源张量（源句子的词元索引）和一个目标张量（目标句子的词元索引）。在两个句子的末尾会添加“\\<eos\\>”。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "92baa7c9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2157 2157\n",
      "读取 2157 对序列\n",
      "过滤后剩余 2003 对序列\n",
      "统计词数\n",
      "zh 1368\n",
      "en 3287\n"
     ]
    }
   ],
   "source": [
    "def get_train_data():\n",
    "    # Build (source ids, target ids) training pairs, <eos>-terminated\n",
    "    input_lang, output_lang, pairs = prepareData('zh', 'en')\n",
    "    train_data = []\n",
    "    # enumerate() was unnecessary here: the index was never used\n",
    "    for src_sent, tgt_sent in pairs:\n",
    "        src_ids = input_lang.sent2ids(src_sent)\n",
    "        tgt_ids = output_lang.sent2ids(tgt_sent)\n",
    "        # Append <eos> to both sequences\n",
    "        src_ids.append(EOS_token)\n",
    "        tgt_ids.append(EOS_token)\n",
    "        train_data.append([src_ids, tgt_ids])\n",
    "    return input_lang, output_lang, train_data\n",
    "\n",
    "input_lang, output_lang, train_data = get_train_data()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "411c9ed3",
   "metadata": {},
   "source": [
    "<!--训练时，我们使用编码器编码源句子，并且保留每一步的输出向量和最终的隐状态。然后解码器以“\\<sos\\>”作为第一个输入，以及编码器的最终隐状态作为它的第一个隐状态。-->\n",
    "接下来是训练代码。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8b85d1ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch-19, loss=0.0087: 100%|█| 20/20 [12:00:53<00:00, 2162.6\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAioAAAGwCAYAAACHJU4LAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABKmUlEQVR4nO3dd3xT5f4H8E/SkZaO0AItdFBGWaXsAhZQQBmyBFyIiqBXvXhZiqLiBAfF8VMcgPMiXC/CvSLIFZnKEKGMAlLKhi7oAjrSmbbJ8/uj9NC0SZumSU+Sft6vV16v5JznnHxPD9APz3nOcxRCCAEiIiIiO6SUuwAiIiIiUxhUiIiIyG4xqBAREZHdYlAhIiIiu8WgQkRERHaLQYWIiIjsFoMKERER2S1XuQtoCL1ej7S0NPj4+EChUMhdDhEREZlBCIH8/HwEBQVBqay9z8Shg0paWhpCQ0PlLoOIiIgskJqaipCQkFrbOHRQ8fHxAVBxoL6+vjJXQ0RERObQaDQIDQ2Vfo/XxqGDSuXlHl9fXwYVIiIiB2POsA0OpiUiIiK7xaBCREREdotBhYiIiOwWgwoRERHZLQYVIiIislsMKkRERGS3GFSIiIjIbjGoEBERkd1iUCEiIiK7xaBCREREdotBhYiIiOwWgwoRERHZLYd+KKGtlJTpkF1YCqVCgdZqD7nLISIiarLYo2LEr/HpGLT0dyz48S+5SyEiImrSGFSM8HRzAVDRs0JERETyYVAxwtO9IqgUM6gQERHJikHFiMoelaJSBhUiIiI5yRpUysvL8dprr6F9+/bw9PREhw4d8NZbb0Gv18tZFlQ3g4q2TN46iIiImjpZ7/p577338MUXX2D16tXo3r07jh49iscffxxqtRrz5s2TrS6Va0V+05YzqBAREclJ1qBy8OBBTJw4EePGjQMAtGvXDj/88AOOHj0qZ1lwvxlUSst56YeIiEhOsl76GTJkCH777TecP38eAPDXX39h//79GDt2rNH2Wq0WGo3G4GUL7FEhIiKyD7L2qLz00kvIy8tD165d4eLiAp1Oh3fffRdTp0412j4mJgaLFy+2eV1Sj4pODyEEFAqFzb+TiIiIapK1R2X9+vX4/vvvsXbtWhw7dgyrV6/Ghx9+iNWrVxttv3DhQuTl5Umv1NRUm9Slcq0YTCsEUKYTNvkOIiIiqpusPSoLFizAyy+/jIceeggA0KNHDyQnJyMmJgbTp0+v0V6lUkGlUtm8rspLP0BFr4q7K+/iJiIikoOsv4GLioqgVBqW4OLiIvvtye4uVYIKx6kQERHJRtYelQkTJuDdd99F27Zt0b17dxw/fhwfffQRnnjiCTnLglKpgJuLAmU6AS3v/CEiIpKNrEHls88+w+uvv45//OMfyMrKQlBQEP7+97/jjTfekLMsABW9KmU6HXtUiIiIZCRrUPHx8cGyZcuwbNkyOcswyt1VicJSBhUiIiI5cZSoCS7KiluS9bzph4iISDYMKiZUzp2iF0wqREREcmFQMeFmhwqDChERkYwYVExQ3uxRYU4hIiKSD4OKCUpe+iEiIpIdg4oJCunSj7x1EBERNWUMKiawR4WIiEh+DComVA6mFQwqREREsmFQMeFWj4rMhRARETVhDComSGNUmFSIiIhkw6BigoI9KkRERLJjUDGBY1SIiIjkx6BigjThm8x1EBERNWUMKibwWT9ERETyY1AxQckJ34iIiGTHoGICJ3wjIiKSH4OKCRxMS0REJD8GFROkMSp6mQshIiJqwhhUTLg1RoU9KkRERHJhUDGBU+gTERHJj0HFBGkeFfaoEBERyYZBxQQFb08mIiKSHYOKCQqOUSEiIpIdg4oJnEeFiIhIfgwqJlQGFSIiIpIPg4oJvPRDREQkPwYVE5Sc8I2IiEh2DComcMI3IiIi+TGomHBrHhWZCyEiImrCGFRMUPCuHyIiItnJGlTatWsHhUJR4zVr1iw5ywJQ9dKP
vHUQERE1Za5yfvmRI0eg0+mkz6dOncLIkSPxwAMPyFhVBc6jQkREJD9Zg0qrVq0MPi9duhQdO3bE0KFDZaroFuXNviY+64eIiEg+sgaVqkpLS/H9999j/vz50viQ6rRaLbRarfRZo9HYrB4Fn55MREQkO7sZTLtp0ybk5uZixowZJtvExMRArVZLr9DQUJvVw0s/RERE8rOboPLtt99izJgxCAoKMtlm4cKFyMvLk16pqak2q6eyT4c9KkRERPKxi0s/ycnJ2LVrF3766ada26lUKqhUqkapqfKuH45RISIiko9d9KisWrUKAQEBGDdunNylSDjhGxERkfxkDyp6vR6rVq3C9OnT4epqFx08ADjhGxERkT2QPajs2rULKSkpeOKJJ+QuxQAnfCMiIpKf7F0Yo0aNsstxILzrh4iISH6y96jYK074RkREJD8GFRM44RsREZH8GFRMuDVGhUmFiIhILgwqJijZo0JERCQ7BhUTbs2jwqRCREQkFwYVExS89ENERCQ7BhUTFOClHyIiIrkxqJjAwbRERETyY1AxQSk9lVDeOoiIiJoyBhUTsgtLAQD/OZoqcyVERERNF4OKCT/GXQEA5BSVYdfpTJmrISIiapoYVMzw5JqjcpdARETUJDGoEBERkd1iUCEiIiK7xaBiJiEECrTlcpdBRETUpDComOm59ScQ+eZ2JKTlyV0KERFRk8GgYqZNJ9IAAF/vuyxzJURERE0Hg4oJH9zf0+jyxBtFjVwJERFR08WgYkKftn5Gl/+Vmtu4hRARETVhDCom5JeUyV0CERFRk8egYkKIXzOjywN9VY1cCRERUdPFoGKCa+VDCasZ3b11I1dCRETUdDGomODiYjyoGF9KREREtsCgYoKpHhUiIiJqPAwqJrgwqBAREcmOQcUEVyV/NERERHLjb2MT2KNCREQkPwYVC8Ql52DbqQy5yyAiInJ6rnIX4IjuW3kAAPD780PRoZW3zNUQERE5L/ao1JOo8v5KTrFsdRARETUFDCr1JKokFX3VD0RERGR1sgeVq1ev4tFHH0WLFi3QrFkz9O7dG3FxcXKXBQBYNqV3resZU4iIiGxL1jEqOTk5GDx4MIYPH46tW7ciICAAly5dQvPmzeUsS9IpsOb4k6q9KLwviIiIyLZkDSrvvfceQkNDsWrVKmlZu3btTLbXarXQarXSZ41GY8vyoFTUjCLFZTrpfdXQcr1Ai91nszC+ZxA83V1sWhcREVFTIeuln82bNyMqKgoPPPAAAgIC0KdPH3z99dcm28fExECtVkuv0NBQm9ZXqC2vsSwjr0R6/9KGeOn9g18exIIfT+LdX0/btCYiIqKmRNagcvnyZaxcuRKdOnXC9u3bMXPmTMydOxdr1qwx2n7hwoXIy8uTXqmpqTatL0NTUmNZuf5WL8q1fC1KbvawXL5WCADYeTrTpjURERE1JbJe+tHr9YiKisKSJUsAAH369EFCQgJWrlyJxx57rEZ7lUoFlUrVaPW5udTMcYcTsw0+F5fq4OF261KPu6vs45OJiIichqy/Vdu0aYOIiAiDZd26dUNKSopMFRky5wnK1XtdMvO0JloSERFRfckaVAYPHoxz584ZLDt//jzCwsJkqsiQOc/7OXT5hsHnUp3eVuUQERE1ObIGleeeew6xsbFYsmQJLl68iLVr1+Krr77CrFmz5CxL0q2Nb51tSsoZTIiIiGxF1qDSv39/bNy4ET/88AMiIyPx9ttvY9myZXjkkUfkLEsS6OtRZ5uSKrcrExERkXXJ/lDC8ePHY/z48XKXYTEte1SIiIhshreoNNDec9egr3LLstrTTcZqiIiInAuDSgOdTtdg2Id7pM++nrJ3UhERETkNBhUrSMkukt6X6wQ0JWUyVkNEROQ8GFTqMPeuTvVqn55Xgp6LdiDTyKy2REREVD8MKnWYP7KzRdttO5Vh5UqIiIiaHgYVMzwYFVLvbXRVBtgSERGRZRhUzPDu5B713uatX05j0eYE
G1RDRETUdDComMHNRWnRJaDvDiRZvxgiIqImhEHFTCoLn4qs1wt8sP0sdiRwzAoREVF9cdIPM/UL87Noux2nM7F89yUAQNLScdYsiYiIyOmxR8VMUe380T2o7ocUVvdjXKoNqiEiImoaGFTqIcqCXpVdZ7Kk90LwTiAiIqL6YFCph1Jdw4LGZ79ftFIlRERETQODSj30ClE3aPuPdp63UiVERERNA4NKPYzp0QZe7i4N2kdJmc5K1RARETk/BpV6UHu64fgboxq0jz3nsupuRERERAAYVOrN3cL5VCrN/P4YB9USERGZiUHFAq5KRYO215brrVQJERGRc2NQscC/nxyIlt7uFm/PoEJERGQeBhULDOzQAkdeHWHx9rtOZ1qxGiIiIufFoGIhhcLyyz+lOvaoEBERmYNBRQYNG+FCRETUdDCoyID3/BAREZmHQUUGvDuZiIjIPAwqDTC+ZxuLthPsUyEiIjILg0oDDO3cyqLt9MwpREREZmFQaYA+bZtbtiGv/RAREZmFQUUG7FEhIiIyD4NKA3i4WfYkZVcX3qBMRERkDgaVBgjxa4bZw8MxPTqs3tsRERFR3WQNKosWLYJCoTB4tW7dWs6S6u2F0V2weGIkPpvax+xt9Lz2Q0REZBbZe1S6d++O9PR06RUfHy93SRaZ0CvI7LZ6DqYlIiIyi6vsBbi6mt2LotVqodVqpc8ajcZWZdmUjj0qREREZpG9R+XChQsICgpC+/bt8dBDD+Hy5csm28bExECtVkuv0NDQRqzUelKyi+QugYiIyCEohJDvOsTWrVtRVFSEzp07IzMzE++88w7Onj2LhIQEtGjRokZ7Yz0qoaGhyMvLg6+vb2OWblS7l7eY3TZp6TgbVkJERGS/NBoN1Gq1Wb+/Zb30M2bMGOl9jx49EB0djY4dO2L16tWYP39+jfYqlQoqlaoxS7SZvOIyqD3d5C6DiIjIrsl+6acqLy8v9OjRAxcuXJC7FJtbseei3CUQERHZPbsKKlqtFmfOnEGbNpY97E9uD/U3f8xMQUm5DSshIiJyDrIGlRdeeAF79+5FYmIiDh06hPvvvx8ajQbTp0+XsyyLLb2vJxJjxuLBqJA62xaX6RqhIiIiIscma1C5cuUKpk6dii5duuDee++Fu7s7YmNjERZWv5le7YlCocCbE7rX2W5HQmYjVENEROTYZB1Mu27dOjm/3ma8VK7wUbkiX2v68o63SvYpbIiIiOyeXY1RcSYHX7kLQ8Jbmlzv7cGgQkREVBcGFRvxVrmiZ4ja5HqVK3/0REREdeFvSxtydTH943VzUeJqbjFknG+PiIjI7jGo2NCdXQNMrjuRmovBS3/HxzvPN2JFREREjoVBxYa6BPrU2ebT3znxGxERkSkMKjbk5qIwq52eT1MmIiIyikHFhlxdlPjwgV51trv/iwONUA0REZHjYVCxsfv71T1L7bGUXJRwploiIqIaGFQaweFX78LuF4bV2mbzX2mNUwwREZED4axjjSDAxwOoY1ytjuNUiIiIamCPip34at9lHEvJwZWcIkz/52FczS2WuyQiIiLZsUfFTiReL8S9K24Nqh289HckLR1X53YlZTpczCqAQgGk5ZZgZESgLcskIiJqVAwqjahrax+czci36j6nfXsIR5JypM+/zBmCyGDTU/cTERE5El76aUTfPT7A6vusGlIA4GJWgdW/g4iISC7sUWlErdUeVtuXXi+QX1JeY/mVnCKrfQcREZHc2KNix55ecxSl5XoUl+qwcs8lvLvltLRu2j8PoddbO2psk6+tGV6IiIgcFXtU7NiO05n4Yu8lfFTlwYVT+rdFeIA3/rx4w+g2Cpg3bT8REZEjYFCxcx/vMny68gNfHED3INODZRXMKURE5ER46cfOiWrzwOUUlWH/xesm2zOnEBGRM2FQISIiIrvFoOJkeOmHiIiciUVBZfXq1diyZYv0+cUXX0Tz5s0xaNAgJCcnW604ZxTRxlfuEoiIiByGRUFl
yZIl8PT0BAAcPHgQn3/+Od5//320bNkSzz33nFULdDZfPNrPpvtv5s7x0URE5Dws+q2WmpqK8PBwAMCmTZtw//334+mnn8bgwYMxbNgwa9bndEL9PW26/4gg9tgQEZHzsKhHxdvbGzduVMzjsWPHDowYMQIA4OHhgeJiPvW3NgobDyJ5fNURm+6fiIioMVnUozJy5Eg8+eST6NOnD86fP49x4yqe8puQkIB27dpZsz4iIiJqwizqUVm+fDmio6Nx7do1bNiwAS1atAAAxMXFYerUqVYtkIiIiJouhRDVpxRzHBqNBmq1Gnl5efD1dZyxGa9sjMfaQyk223/S0nE22zcREVFD1ef3t0U9Ktu2bcP+/fulz8uXL0fv3r3x8MMPIycnx5JdNimT+wTLXQIREZFDsCioLFiwABqNBgAQHx+P559/HmPHjsXly5cxf/58qxbojPR6h+3EIiIialQWBZXExEREREQAADZs2IDx48djyZIlWLFiBbZu3WpRITExMVAoFHj22Wct2t6RtGvpJXcJREREDsGioOLu7o6ioiIAwK5duzBq1CgAgL+/v9TTUh9HjhzBV199hZ49e1pSjsMJ9PWQuwQiIiKHYFFQGTJkCObPn4+3334bhw8flm5PPn/+PEJCQuq1r4KCAjzyyCP4+uuv4efnZ0k5DslbxRlkiYiI6mJRUPn888/h6uqKH3/8EStXrkRwcMXg0K1bt+Luu++u175mzZqFcePGSZPG1Uar1UKj0Ri8HFXsK3fJXQIREZHds+i/9W3btsUvv/xSY/nHH39cr/2sW7cOx44dw5Ej5s2mGhMTg8WLF9frO+wVe1SIiIjqZvFvS51Oh02bNuHMmTNQKBTo1q0bJk6cCBcXF7O2T01Nxbx587Bjxw54eJg3ZmPhwoUGdxVpNBqEhoZaVD8RERHZP4uCysWLFzF27FhcvXoVXbp0gRAC58+fR2hoKLZs2YKOHTvWuY+4uDhkZWWhX79bTxPW6XTYt28fPv/8c2i12hqhR6VSQaVSWVIyEREROSCLgsrcuXPRsWNHxMbGwt/fHwBw48YNPProo5g7dy62bNlS5z7uuusuxMfHGyx7/PHH0bVrV7z00ktm98wQERGR87IoqOzdu9cgpABAixYtsHTpUgwePNisffj4+CAyMtJgmZeXF1q0aFFjORERETVNFt31o1KpkJ+fX2N5QUEB3N3dG1wUWc5VqZC7BCIiIquxKKiMHz8eTz/9NA4dOgQhBIQQiI2NxcyZM3HPPfdYXMyePXuwbNkyi7d3NJ8/3AfPDKt7PE99qD3drLo/IiIiOVkUVD799FN07NgR0dHR8PDwgIeHBwYNGoTw8PAmFTQaanzPILx0d1er7pNPESIiImdi0RiV5s2b4+eff8bFixdx5swZCCEQERGB8PBwa9dH9aQXjCpEROQ8zA4qdT0Vec+ePdL7jz76yOKCqGGYU4iIyJmYHVSOHz9uVjuFgoM55cQeFSIiciZmB5Xdu3fbsg6yFuYUIiJyIhYNpiX7xR4VIiJyJgwqdqZHsLpB2zOmEBGRM2FQsSN+zdywebZ5M/ua0tAOlXnrjmPeOvPGIxEREdkag4od6RXaHAqFAsum9Mb4nm0s2kdDLv1cL9Di5xNp+PlEGvKKyizeDxERkbUwqNgR/c2MMalPMD5/uK/RNk/f0aHGMneXW6exIR0qxaU66b2SfzKIiMgO8NeRHegS6AMAmNwnqM62T97evsayrm188OfLdwIARAN6VK4VaKX3ecXsUSEiIvkxqNiBDf8YhJ/+MQiTegfX2TbAxwNn377bYJleCOlhhA0Zo1JSdqtHxYUPNyQiIjvAoGIHvFWu6NvWz+zJ8jzcXGosq9yytjEq2xMy8MAXB5CaXWR0vbZML72Pjvkd/d/dZVY9REREtsKg4gSEuDUjsKmYkl9Shr//Kw5HknLw2qZTNdZn5JXg8e+OGCy7lq+t0Y6IiKgxMag4gGeGdYSXuwueHdHJ
6PqKoHLr/cFLNwAA64+kYP2RFFy+VoAei3ZI7Y2NP3lqzVHrF05ERNRAFj09mRpXVJgfXhjVxeS4ER8PVyirXDaa+nUs/npzFF7aEF/xeUBbg/bVd1Ou0yP+ap51iyYiIrICBhU7tmxKb/x1JRfDuwRAWcvg1vfv74lynd5g2aHLN6T3l64VGKyrvEyUU1iKw0nZeOb7OCtWTUREZD0MKnZsUp9gTOpT951AYS28kKkpMVj29L9uhY/DidkG65QK4HxmPkZ9vK/OfecVlUHdzM3MiomIiKyLY1Qc1NDOrQw+1+dmYoVCgTd/TjCr7ewfjmHcp38gLjm77sZERERWxqDioFp4uxsuqEdSSc0uwsEql4Zq88eF60hI0+C+lQfrUR0REZF1MKg4qMoBsr1DmwMAVK4151YxJT2vpO5GREREdoBjVBxU/3b+OPDynWjlowIAqD05joSIiJwPg4oDC2ruKXcJRERENsVLP0RERGS3GFSIiIjIbjGoEBERkd1iUCEiIiK7xaBCREREdotBhYiIiOwWgwoRERHZLVmDysqVK9GzZ0/4+vrC19cX0dHR2Lp1q5wlERERkR2RNaiEhIRg6dKlOHr0KI4ePYo777wTEydOREKCeQ/MIyIiIucm68y0EyZMMPj87rvvYuXKlYiNjUX37t1lqoqIiIjshd1Moa/T6fDf//4XhYWFiI6ONtpGq9VCq9VKnzUaTWOVR0RERDKQfTBtfHw8vL29oVKpMHPmTGzcuBERERFG28bExECtVkuv0NDQRq62aRNCyF0CERE1MbIHlS5duuDEiROIjY3FM888g+nTp+P06dNG2y5cuBB5eXnSKzU1tZGrbdqYU4iIqLHJfunH3d0d4eHhAICoqCgcOXIEn3zyCb788ssabVUqFVQqVWOXSDcxpxARUWOTvUelOiGEwTgUsh+89ENERI1N1h6VV155BWPGjEFoaCjy8/Oxbt067NmzB9u2bZOzLDKBMYWIiBqbrEElMzMT06ZNQ3p6OtRqNXr27Ilt27Zh5MiRcpZFVrT3/DUEN/dAeICP3KUQEZEDkjWofPvtt3J+vdPZ/uwd+D42GVtPpeN6QalF+7inVxA2/5VmdF1JmQ5uLuZfLUxIy8P0fx4GACQtHWdRPURE1LTZ3RgVslyX1j54e1IkWqs9TLYZ1LFFrfv4dGof3Nc3xOg6fT2v/ZxNz6/fBkRERNUwqDghV6Xp09ortDmeH9m51u0/fKAnFozuUmO5QtHg0iQlZTpoy3XW2yERETklBhUn5Ko0nSiae7phzl2dat1eoVDA16PmVUF9PbtUTOWlcp0efd/eiX5v74Kuvt00RETUpDCoOKGrucXS+xmD2hmsm17tsykKI90ne85dq1cdChgPTNmFpSgq1aFAW45r+bwVnYiITGNQcUKdAm/dYZOSXWSwzsPNxax9KI0ElWfXn6hXHSYvFVVZ/tPxK/XaJxERNS0MKk6oaj44nVb3gxvbGBl8a43xKMbCDgB8uz9Rep9fUt7wLyIiIqfFoOKEqg5R8XS/1YPSoZVXjbbeKlejF2isOG62hrWxKdL772OTbfhNRETk6GR/1g9ZX9XxqVUHqwY396zRVqkwPh6lPkNcj6fkoFCrQ5vmHmjfwgtKI4N5hRDS9+Rrb/WisEeFiIhqw6DihJpV6UVJyS7Cuqdvwz/3J2LRPd1rtHV3VcLNpWH9J5NXHJDeT7stDG9PiqzRRlNSDk83F/znqOETr9Webg36biIicm689OOEqk/YdluHFvjqsSgEVelR+eSh3gjwUeHrx6IwsH3FJHBVA4ul0eVfNy/lpOUWY/H/TkvLey3egb5v78Rrm07VWisREVFV7FFxQnd1C6izzcTewbinVxAUCgXCA7wR7OeJcT3b1Pu7TM2t8tSao7heYHjrcYG25mUeY+NmiIiIKjGoOKGqY04GtPevs52Phxvm1jEJnCl6YTyoJJhxtxEABPvVHDdDRERUiZd+nFzPYLVN
968zEVTMxolpiYioFgwqTs6lgQNlTcnIK8GZdA0amlNM9cgQEREBvPTj9NxqeUChpXR6gdtifgMAbH/2jgbti4/6ISKi2rBHxUl1bV0xjf6EXkEWbW9qZtrluy+ix6Lt0ucz6eaNRTGFDyUkIqLasEfFSW2ePQQ5RaUI9K05Pb45TF2R+WD7OYPPJ1JzLdp/JV76ISKi2rBHxUm5uyotDimA+WNcvzuQVGPZ4v8lmP097FEhIqLaMKiQUZmaEum9u0v9/pis+jPJ7LZzfjiOI0nZ9do/ERE1HQwqZNT+C9el9x9P6W3T73rgi4M23T8RETkuBhUyqupg2np2qBAREVkNfwWRUVVntzX2dGUiIqLGwKBCRkW08ZXeKxlUiIhIJgwqZFRwlSctK5lTiIhIJgwqZFTVThT2qBARkVwYVKhOgk8OJCIimTCokFEuVa737EjIlLESIiJqyhhUyKgHokLR1r8ZZgxqB01Jmc2/L7uw1ObfQUREjofP+iGjvFWu2LtgGBQKBWatPWbz7yvX623+HURE5HjYo0ImVc6fwsG0REQkF1mDSkxMDPr37w8fHx8EBARg0qRJOHfuXN0bUqNyc2lYUKk6J0ulILUHRkYESp/5cEIiIjJG1qCyd+9ezJo1C7Gxsdi5cyfKy8sxatQoFBYWylkWVfPIwDCLt01aOg6/zru9xnKFQoGvH4uCyrXijyCDChERGSPrGJVt27YZfF61ahUCAgIQFxeHO+64Q6aqqDq1p2V/TGp76vLV3GIAt+4usvYQlZV7LuGXk2lY+9RtUHu6WXfnRETUaOxqjEpeXh4AwN/f3+h6rVYLjUZj8CLbs/RZP3pRdy+Jy81968xoWx/vbTuLhDQN/rk/0ar7JSKixmU3QUUIgfnz52PIkCGIjIw02iYmJgZqtVp6hYaGNnKVTZOlI1TMCR/Kmz0qtrr0oy3n3URERI7MboLK7NmzcfLkSfzwww8m2yxcuBB5eXnSKzU1tRErbLos7VExp5NEuvRjRuOSMh0mLv8TOxIyDJbnFZehUFtudBtz9ktERPbLLoLKnDlzsHnzZuzevRshISEm26lUKvj6+hq8yPZseXNy5a3P5vSojP3kD/yVmoun/xUnLSsp06HX4h3o/uZ2nMvIx//+SjPYZs3BJKvWS0REjUvWwbRCCMyZMwcbN27Enj170L59eznLITN0CvDGhawCs9q+Pcn4JbyqKsfbmhNULl+veTfYlZxi6f3oZfsAwGDwbEkZL/0QETkyWXtUZs2ahe+//x5r166Fj48PMjIykJGRgeLi4ro3Jll8PKU3gtQedbY7+toITLvN9G3NLb3dAdwaTGvpJRpjc7w89s/DFu2LiIjsj6xBZeXKlcjLy8OwYcPQpk0b6bV+/Xo5y6Jqqg5R6dDKCwcW3oWhnVsZtJncJxhfPNoXADDvrk5o6a2qdZ+VY1NcXBo2mPZwYrZF2xERkWOQ/dIPOZbKMSUTewdh7/lr6BTgjZ3zh0rrTy0eDW9VzT9Wvh6u0JTcGvDqqqzIyC61jFHZdToTT645incnR9aYdE4IAYVCgT3nrtVa74NRpsc8ERGR/eNDCalOrkYmbpvcJxhhLbzQOdDbYLmxkAIAm2cPwQ+HU7Dx+FVk5WsxqU8QgNpvT35yzVEAwKsbT9WYtO2pNXG4fL0Al6/VPotxmY5hmIjIkTGoUJ2C1B6Y1DsInu4u8HBzAVBxy3K/MD+z99GupRcWju2GfwwLR2ziDQzvEgDA/AnfZq89bvB515lMs77XVHAiIiLHwH/FqU4KhQLLHupjlX2pm7lhdPfW0mdbTaFfiQ9+JiJybHYxjwo1XcoqPSqakjI+nJCIiAywR4VkVdmjkpJdhOk3byu+t08wIoKsM5kfO1SIiBwbgwrJqnIw7eubTknLfjp+FT8dv2qV/bN/hojIsfHSD8nKyHxtREREEgYVklXlpR9bYQ4iInJsDCokK1sHFSIicmwMKiQrBhUiIqoNgwrJ
SsmJToiIqBYMKiSrpBu1T4FPRERNG4MKySo1u1juEoiIyI4xqJBTs+Y8KjG/nsGnv12w4h6JiKgunPCNnJq+jocdmitTU4Iv910GADx9Rwfp4YxERGRb7FEhp7L2qYGY2DtI+lzXo4P0eoHDidkoKi2vtV15lR2V6mz0BEUiIqqBQYWcyqCOLRHeylv6LOroUVl1IAkPfnkQQ97bjdTsIpPtqt5FrdNxYn4iosbCoEJOTV9H58e6wykAgOzCUtz+/m4AwMbjV/D0mqMo1N7qZamad8rq2ikREVkNgwo5lKkDQuvVXmeiR0WnFzh5Jdfgkg5QcSnoufV/YcfpTHx3IAlARa9MWZXLPQOX/AZtua5+hRMRkUUYVMihvDK2W73amxpM+8mu87jn8z+ReN1wHpeqweX72GQAwN//FYehH+yRlgsBrDucWq86iIjIMgwqJKuurX2k924uCjx9RweTbd+a2B0+Hm517rNqJ4mpISqf/n7R6HJdlY3T80oAADtOZ9Zo9+bmhDrHvxARUcPx9mSS1buTe+C+lQcAALOGh+PZEZ0x765O8FK5YtWfiQjxa4aREYEG27x0d1e8t+2syX1uOHZFeq/TCxSX6vBjXCru7BaI4OaeAAB3F6XRu3eOJGWbXbuC0/8TEdkce1RIVv3C/KT3D0ZVjD/xUlXk58cHt68RUgDgmWEdDW5Bri6lyt07eiHwwfZzeP3nBIz/9A9pualbjB/752Gzay8t56BaIiJbY48Kye7ykrEoKtPBW2X+H8fF93SHUqHAA/1C8PA3hwzW/f2ODtLkbGcz8nE2Ix8AkFNUVu/aaru8Y63J5IiIyDQGFZKdUqmoV0gBgObN3PHxlN5G1/Vp21x6fzGrAGEtmllcW/uFv5pcV/2OISIisj5e+iGnoXKt+OM8MqK1wfLkG4YTuVlrEGw5Z6glIrI59qiQ03FRKuDj4Yr8kprT4i/dehbZhVqrfA97VIiIbI89KuTwxvdsA6BibEqlHsFqo22/2HsJ/zl6xei6+irnVPpERDbHHhVyeB892Bt/G9IePUOaS8s6tPLCgUs3bPq95ZxKn4jI5hhUyOG5uyrRp62fwTJXpe07C3W89ENEZHO89ENOqZm7i82/43K16fcbIr+kDCVlfH4QEVF1sgaVffv2YcKECQgKCoJCocCmTZvkLIecSG1T8VuLfzN3q+ynUFuOHot2oOvr26yyPyIiZyJrUCksLESvXr3w+eefy1kGOaHmVgoRtXFRWmcK/coJ6QDg8rUCq+yTiMhZyDpGZcyYMRgzZozZ7bVaLbTaW7eWajQaW5RFZBZLZ6b95/5ErNhzEeN7BuHZEZ0M1h1OzEaHVt7WKI+IyCk41BiVmJgYqNVq6RUaGip3SdSEWTqY9q1fTuN6QSm+O5CEVzedMljH8blERIYcKqgsXLgQeXl50is1NVXuksiO2XqcijWe9XMsOQdVH8J8Ki2vwfskInImDhVUVCoVfH19DV5Eplj79uHJfYIxtkdrqD3dbu7fOvudUeWJzWsPpaC4lHf/EBFVcqigQlQflj6LZ+UjfWsse25EZ3w8pTdWPNIPLb0rBupmF2qx9lAKNCWGT2U+eOkGHvrqIC5m3RoYm5ZbjAmf7cd/jhj2AmZqSqCpNtX/w9/EWlQ3EZEz4oRv5LS05ZYFlTE92kjvnxvRGfOqDXitvNtn5vfHAAD7zl/DF9P6AQD0eoGpX1cEjSdXH8GeBcMBAJ/suoD4q3l4ccNJg30Z6/Q5npJrUd1ERM5I1qBSUFCAixcvSp8TExNx4sQJ+Pv7o23btjJWRs4gLjnH4m39mrkhp6gMIyICaqxTKgxvS96WkCG9f2PzrcGxSTeKIIRA8o0i7L943eJaiIiaMlmDytGjRzF8+HDp8/z58wEA06dPx3fffSdTVeQsLmRZPifJvheHI1OjRXhAzVuFa5s/5fvYFIPPn/x2Act2XbC4DiKipk7WoDJs2DAIK9w5QWRtPh5u8PFw
M7quPhO9MaQQETUMB9OS05rcJ7jW9d6qmjn9uRGd69xv9Us/AHDvij/R7uUt5hdHRERmYVAhp/XG+Ai09W+GF+/uUmPd1AFtEb9olPR5enQY7uwagCdvb1/nfo31qByz4gBYYwGKiKip4r+I5LT8vNyx78WKMVDvbzsHoKLHZOqAULTyUUGhUODfTw5ESnYRpg4wf/C2lR7xY9KdXWsO4K1UrtPjbEY+Itr4QmnrQoiI7ACDCjUp0R1bIMDXQ/o8OLwlBtdzH8Yu/VhTWS3zv/z9X3H47WwWxvZojRWP9LNpHURE9oCXfqhJ8PWoyOTd2vg0eF/WemqyKbUFld/OZgEAfo3PMNmGiMiZsEeFmoTDr45AqU5v8k6e+igus+0U96U6690Jl5ZbjNNpGoyICLTaPomIGhODCjUJHm4u8HBzscq+bD1zbFk9Z9Q9cOk6Hv76kPQ5aek46f2gpb8DAGYN74gFo7tap0AiokbESz9EdiDARyW9r+3SjzFVQwoAFGornh205WS6tGz57ksNqI6ISD4MKkQy+2XOENzXL0T6XD2oCCGkiRHdXer+K7vx+FUAwKy1xwyWL9990VhzIiK7xks/RDK6tGQsXJQKdGjlhX3nryEhTYO/ruQBAP5vxzlsiU+Hj4cbFACGdwlAqRm9La9tOoXXNp2qsfyD7ecwpX8oWnqrjGxFRGSf2KNCVE+h/p4N3odSASQsHi3dQdTM3RUPVZnL5a/UXHz2+0VcvlaIv1JzcSI1Fx/vOt/g7z2bnt/gfRARNSYGFaJ66hLY8FucVzzSD17VZqBtXWV+l4nL/6xzH3q9QGk9B9629HGvV3siIrnx0g9RPe06k1Vnm6kDQvHD4VSDZV0CffCfmdG4mJWPvm39amxjxvATA/d/cQBXcorrtU1ZuXVufU7NLkIbtQdc61s0EVE9MagQ2cDE3sFQKBRwVSqw5mAyAEAvBNSebugX5m90my6tfev1HZY8X8icMS4AUFKmw5Jfz+DOrgEY1iUAp67moXkzN4T4NcMfF65h2reHEerviadu74CpA9rCjYGFiGyE/7oQ1dO028Kk96NuTqT27ycHGrTRluuxZHIPvDUxEn3aNgcAPBAVgtp4Wmmel9qYe+vz+9vOYc3BZMxYdQTJNwox/rP9GPLebhxPycHKPRW3OqdmF+ONnxPQ6dWt9b4ERURkLvaoENWTr+etvzZfPRYFACgqLTdo0y/s1qWdf/1tIE6m5mJghxa17tfd1fb/bzAnqKw5mIR//pkofa683RkAJq84gLb+zWps89z6E1j+SF/rFElEVAV7VIjq6W9DOqBdi2aYe1cnaZnK9VZvyINRIfCuMlDWW+WKQeEt63xGkDlzpFji9k4tpffmBJU3fk4w+Lxs1wWDzynZRTW22RKfXmMZEZE1sEeFqJ78vdyxZ8Fwg2VVQ8jLY7pZtF9r9ajc1zcEG45dkT6vfnwABsb8hmv5WpSaGEybcqMIj393uMadSEREcuO/SkRWcuTVESgp08Hfy/JbgL3cXVBYWv+HHnZo6YUdz92BS9cK0TnQG+9OjsTstccwoVcQlEoFwlt541q+FmU6PYQQiL2cjQHt/ZFdWIpFmxNk7RG5UaBFi5uT0On1Asqboe9GgRY5RaUID2j47eBE5LgYVIispJVPw2d81Vo4KHVweEu4uijRpXXFL3UPNxd8M72/tN7tZm9NobYc7Rf+2uA6a3PySi6KS3V1jskBgEWbE/DdgSQAwOvjI/D2L6cR1qIZkm/curz0+/ND0aGVt63KJSI7xzEqRHakXG94aWZAO3/4ehj+f2Jw+K0A8FD/UPQKUWPB3V1q3a+7S0Uvxcs/xVupUkNe7hVjdIQQuOfzPzHlq1jcKNAiIS0PA5fswr8PJRvdrjKkAMDbv5wGAIOQAlh2GzYROQ/2qBDZqZ9nDUav0ObI0pRgwJLfAADn3rkbf6Xm4c+LBwEAS+/rada+bH1HUeDNWXULtLfufsrUaPHh
9nPI1Gjx6sZTCPTxwPCuASgsLcdz605AU1Jm1r5d6xiELLdtpzIAAHdHtpa5EiLnxKBCZEdmDGqH7w4kYXzPNugV2hwAEODrgQ/u7wlvlStUri7o384Pfx/aAR3rcTkkS6O1Sb0b/zEIk1ccgLZcj/ySMvRYtENaN/bTPwzaPrnmqEXf4elu3vwy1/K18HBTwsfDDecz8+HmokT7ll412p26moetp9Lx7IjODZ6orkBbjpnfxwEAdjx3Bzpb4fEKRGSIQYXIjrw6rhvG92yDniHNDZY/EBUqvVcoFFhYzzuLjibnWKM8AwE+Kqmn5mpusUFIsSZtuR4F2nK8vOEkZg0PR7c2hjP4puUW42hyDub+cBwuSgWOvzESoz7eBwA4/Mpd8PZwRcQb2wEAiyZEYNH/Ki4xLd99CUlLxzWottQqt2qP+ngf7usbgv97sFeD9klEhhhUiOyIm4sSUe2MT7Evt3v7BuNvQ9rj898vYuupDPz2/FBkakps/r0lZTpEvlkRNH45mY7hXVrhqTs64OGvD+G/M6PxwBcHpbY6vcCpK3nS58pLZpUqQ0qlcxn50gDkSvsvXEcLb/cagaiSEAJxyTnoFOCDMZ8Y9hptOHYFG45daXAAIqJbFEII6zylTAYajQZqtRp5eXnw9a3fc1KImpL2C7fAnL/pSUvH4Vq+FpmaEiReL8SYyNbQC+Bocjb6tvWDR7Vp/lNuFOGOD3bbqOoKXVv74GxGvs32v+KRvhjbow2AijuW7vm84snViTFjoVAokKkpwfP/+Qsx9/bAGz+fwu5z1+rc56FX7pLG7ZgrU1MChQII8KnYrqRMhzKdHtmFpRj6wR64KhW4uGRsPY9OHk+uPopdZzJx4d0xDb689p8jqUjNKcLzo2ofME6OpT6/vxlUiJqA4lId/v59HPadv4ZXx3bDff1C4OaiwI9xV3Docja2JWSgZ4gam2cPqdd+M/JKcFvMb3U3tHOPRYdhbI82eOirWIPlDw9si7WHUuq9vxdGdUbi9SJsOHYFT93eHs8MC4e/lzueW38Ccck52LtgGPaev4YZq47g2RGd4K1yxTtbzgAAVs3oj8e/O2J0vyMjAvH1zcc22LN2L2+R3jekd0mnF+j4SsXt9H+8OByhRh7f0FQIIaBQ2PfA8vqoz+9vXvohagI83V2w5okBKC7VGQxOfXxwe4zvGYSubXzwWHS7eu/XzaXh/3DOGt4R3dr4YlyPNjicmI1v9ydi+SN90enVrQ3et7nWHEyWnnJdlSUhBai4C+rDHecBAF//kYiv/0iEh5sSJWUV8+T89+gVvLjhJICajygwFVIAYOfpzHrVIYRAblEZmjdzk37Jlev00Atge0IG2rXwQmRwxS8JU78EdXqB/Revo1eIGmpPN6TnlaCVjwrbEzIwKqK1NE4pLbcYzZu5oUxn+H9fIQT0AnU+QqLSxax8BPh6oFwnDO5Wu5pbjJEf70VJmR5TokLx3v237ng7kpSNz3+/iG+mR1nUg6PTC+QWleKD7eew7kgq1j45EIPCW9a9oQnach30evMHgtdl2a7zWLbrAsZEtsbKR/tZZZ+mlOv0OJ6ai6gwP7sJRuxRIaIGWb77Ij7Yfq7WNv5e7sgvKcMPT92Gn45fRa8QNd755QxGRgTioym9jW5T9X/lDeXhpoSrUmlw+7QtvTMpEq9tOmXT7zj/zhhsPZWOH+Ou4I8L1wEAC0Z3QfKNQkSF+ePB/qH1+hm+OSECi/93GnGvjUALbxWW776IS1kF8PFwxeqbIc7YZbjg5p6Y0j8UH+08X+v+Ky+lFZWWY9LyP3E+swBrnxqI9i29kKXRIqxFMxxLycET35l/d9jR10agpbfK4Dj/8/douLoo8FdqLmIv38CX0271QJWU6eDuooRSqYBOL6TwZOznVFdPUIG2HHPWHsMrY7uhYytv5BSVooW3CqXlenR+rSJk92/nhyNJOfjf7CH460ouXtt0CptmDUbvm3f0matqfYkx
Y3ElpxgHL9/A6O6tUagtR1BzT5Pb6vUC//j3MUzuG4zR3eu+hb7yMrGPhytOvjnKZmGFl36IqFHp9AIKACXlOuw8nYmB7VugpEyH5/5zAs8M7YhRZvwDWd3e89cw/Z+HzW7/1sTuBg9UTFg8Gj8dv4qQ5p4Y3jUAAJCVX4InVx/FySoDbq0hduFdUCqA+f/5C/svXrfqvp3FlrlDcDY9H8//9y+r7veLR/tJt4gb87ch7fHauG64klOM29/fjVB/T6RmFwMApg5oixmD2mH0sn01ttv+7B04kpSNlt7uaOGtwgNfHMT06DDMG9EZ+SVlGPrBHotrXnpvDzw0oK0Uatq1aIbdLwyTQoEQAlO+jIWmpMzs8Vm/zBmCyGC19LnycRQz/xWHbQkVc/0ceXUEbhRq8dSao3j6jo543YwwvWB0F8waHm7BUdbOoYLKihUr8MEHHyA9PR3du3fHsmXLcPvtt5u1LYMKkXO7kJmPkR/vw+Q+wUi8XghtuR4uSuDUVY1Bu1/n3o6IoIp/A5JvFEKpUNQ6nuHQ5RuYUm08iqV2zR+K8ICKOW2s2QtUSaEAnh/ZWbqURNTYurb2wbZn77DqPh0mqKxfvx7Tpk3DihUrMHjwYHz55Zf45ptvcPr0abRt27bO7RlUiJqecp0e5XqB4lIdtidk4MGoUOlBhvVRUqZD19e3GV33+/ND0cpHhROpuegepEaZTo9r+Vr4erihbYuKAGRscKM5vUCPD26HVX8m1Vj+7fQo3NG5Fa7la5FXXAYfD1eE+N0KW7YIQY7oySHt8c3+RLnLaHKsfcu9wwSVgQMHom/fvli5cqW0rFu3bpg0aRJiYmLq3J5BhYgaSggBTUk51J5uSM0uQoifZ4Ouy7+/7SxW7LkEANJ4j9NpGmTml2B4l4pLUNcLKkKPq1KBgtJy+Hq41bnfM+maGvO2VGXqDiWFAkiMqfglU1Kmw+oDSdh0Ig0b/zEIHm4u+HZ/ovScpdpMjw7D6oPJWPvUQKw/kor0vBL0adscG+Ku4nqB4czHMwa1w2vjuuGF//6FTSfS6tx3dZ9N7QM3FyUOXrqO6I4t8erGeHz4YC/p55epKcHAJY5/t5mlql/mtLW3J3bHNAsG29fGIYJKaWkpmjVrhv/+97+YPHmytHzevHk4ceIE9u7dW2MbrVYLrfbWXwiNRoPQ0FAGFSJqUgq15cgrLoOXyhXlOj1aeFc8uVsIAW25vsZ8N7Up1+mx/mgq+oT6oZWPCoXacrS7+eiBcp0ehaU6qD3rDlK1KS3Xw81FAYVCYRAM45KzsefcNYM5UkrKdGbVn55XjB8OpeD+fqHYd+EaJvUJxu9nszCyWyByi0vh6eaC0nI9sotK8eXey1g0oTvUzdwghMCla4Voo/bAP/59DC+P6Vpjcr+MvBI8ueYIEtI00vxD70yKxJGkbPx8M3iteWIAQvw8pSd7F5fqkF1UigAfFQpKyuGlcoW7qxJlOr10J9I3f1yWbkOvtGB0lxqD0aPC/HA0OQcvjOqMB6NCUaAtR4dW3gY/m8pf3ZU/04FLfkNWvulHZTzUPxT5JeXYEp8OoGJcVfNmbsgrLoNSocDaQykY3rUVFFAgQ1MCnV6Ptv5e6BTo3eC5cIxxiKCSlpaG4OBg/Pnnnxg0aJC0fMmSJVi9ejXOnat5F8GiRYuwePHiGssZVIiIiBxHfYKKbR+paobqXay1TWqzcOFC5OXlSa/U1NTGKJGIiIhkItuEby1btoSLiwsyMjIMlmdlZSEwMNDoNiqVCiqVqjHKIyIiIjsgW4+Ku7s7+vXrh507dxos37lzp8GlICIiImq6ZJ1Cf/78+Zg2bRqioqIQHR2Nr776CikpKZg5c6acZREREZGdkDWoTJkyBTdu3MBbb72F9PR0REZG4tdff0VYWJicZREREZGdkH1m2obgPCpERESOx6Hu+iEiIiIyhUGFiIiI7BaDChEREdktBhUiIiKyWwwqREREZLcYVIiIiMhu
MagQERGR3WJQISIiIrsl68y0DVU5V51Go5G5EiIiIjJX5e9tc+acdeigkp+fDwAIDQ2VuRIiIiKqr/z8fKjV6lrbOPQU+nq9HmlpafDx8YFCobDqvjUaDUJDQ5GamuqU0/Pz+Byfsx8jj8+xOfvxAc5/jLY8PiEE8vPzERQUBKWy9lEoDt2jolQqERISYtPv8PX1dco/gJV4fI7P2Y+Rx+fYnP34AOc/RlsdX109KZU4mJaIiIjsFoMKERER2S0GFRNUKhXefPNNqFQquUuxCR6f43P2Y+TxOTZnPz7A+Y/RXo7PoQfTEhERkXNjjwoRERHZLQYVIiIislsMKkRERGS3GFSIiIjIbjGoGLFixQq0b98eHh4e6NevH/744w+5S6ph0aJFUCgUBq/WrVtL64UQWLRoEYKCguDp6Ylhw4YhISHBYB9arRZz5sxBy5Yt4eXlhXvuuQdXrlwxaJOTk4Np06ZBrVZDrVZj2rRpyM3Ntckx7du3DxMmTEBQUBAUCgU2bdpksL4xjyklJQUTJkyAl5cXWrZsiblz56K0tNSmxzdjxowa5/S2225zmOOLiYlB//794ePjg4CAAEyaNAnnzp0zaOPI59Cc43Pkc7hy5Ur07NlTmtwrOjoaW7duldY78rkz9xgd+fxVFxMTA4VCgWeffVZa5rDnUJCBdevWCTc3N/H111+L06dPi3nz5gkvLy+RnJwsd2kG3nzzTdG9e3eRnp4uvbKysqT1S5cuFT4+PmLDhg0iPj5eTJkyRbRp00ZoNBqpzcyZM0VwcLDYuXOnOHbsmBg+fLjo1auXKC8vl9rcfffdIjIyUhw4cEAcOHBAREZGivHjx9vkmH799Vfx6quvig0bNggAYuPGjQbrG+uYysvLRWRkpBg+fLg4duyY2LlzpwgKChKzZ8+26fFNnz5d3H333Qbn9MaNGwZt7Pn4Ro8eLVatWiVOnTolTpw4IcaNGyfatm0rCgoKpDaOfA7NOT5HPoebN28WW7ZsEefOnRPnzp0Tr7zyinBzcxOnTp0SQjj2uTP3GB35/FV1+PBh0a5dO9GzZ08xb948abmjnkMGlWoGDBggZs6cabCsa9eu4uWXX5apIuPefPNN0atXL6Pr9Hq9aN26tVi6dKm0rKSkRKjVavHFF18IIYTIzc0Vbm5uYt26dVKbq1evCqVSKbZt2yaEEOL06dMCgIiNjZXaHDx4UAAQZ8+etcFR3VL9F3ljHtOvv/4qlEqluHr1qtTmhx9+ECqVSuTl5dnk+ISo+Edy4sSJJrdxpOMTQoisrCwBQOzdu1cI4XznsPrxCeF859DPz0988803TnfujB2jEM5x/vLz80WnTp3Ezp07xdChQ6Wg4sjnkJd+qigtLUVcXBxGjRplsHzUqFE4cOCATFWZduHCBQQFBaF9+/Z46KGHcPnyZQBAYmIiMjIyDI5DpVJh6NCh0nHExcWhrKzMoE1QUBAiIyOlNgcPHoRarcbAgQOlNrfddhvUanWj/zwa85gOHjyIyMhIBAUFSW1Gjx4NrVaLuLg4mx7nnj17EBAQgM6dO+Opp55CVlaWtM7Rji8vLw8A4O/vD8D5zmH146vkDOdQp9Nh3bp1KCwsRHR0tNOdO2PHWMnRz9+sWbMwbtw4jBgxwmC5I59Dh34oobVdv34dOp0OgYGBBssDAwORkZEhU1XGDRw4EGvWrEHnzp2RmZmJd955B4MGDUJCQoJUq7HjSE5OBgBkZGTA3d0dfn5+NdpUbp+RkYGAgIAa3x0QENDoP4/GPKaMjIwa3+Pn5wd3d3ebHveYMWPwwAMPICwsDImJiXj99ddx5513Ii4uDiqVyqGOTwiB+fPnY8iQIYiMjJS+t7Le6vU72jk0dnyA45/D+Ph4REdHo6SkBN7e3ti4cSMiIiKkX0DOcO5MHSPg+Odv3bp1OHbsGI4cOVJjnSP//WNQMUKhUBh8FkLUWCa3MWPGSO979OiB6OhodOzYEatXr5YGf1lyHNXb
GGsv58+jsY5JjuOeMmWK9D4yMhJRUVEICwvDli1bcO+995rczh6Pb/bs2Th58iT2799fY50znENTx+fo57BLly44ceIEcnNzsWHDBkyfPh179+41+Z2OeO5MHWNERIRDn7/U1FTMmzcPO3bsgIeHh8l2jngOeemnipYtW8LFxaVG4svKyqqRDu2Nl5cXevTogQsXLkh3/9R2HK1bt0ZpaSlycnJqbZOZmVnju65du9boP4/GPKbWrVvX+J6cnByUlZU16nG3adMGYWFhuHDhglSXIxzfnDlzsHnzZuzevRshISHScmc5h6aOzxhHO4fu7u4IDw9HVFQUYmJi0KtXL3zyySdOc+5qO0ZjHOn8xcXFISsrC/369YOrqytcXV2xd+9efPrpp3B1dZX265DnsN6jWpzcgAEDxDPPPGOwrFu3bnY3mLa6kpISERwcLBYvXiwNmnrvvfek9Vqt1uigqfXr10tt0tLSjA6aOnTokNQmNjZW1sG0jXFMlQPB0tLSpDbr1q2z+WDa6q5fvy5UKpVYvXq1QxyfXq8Xs2bNEkFBQeL8+fNG1zvyOazr+IxxtHNY3Z133immT5/u8OfOnGM0xpHOn0ajEfHx8QavqKgo8eijj4r4+HiHPocMKtVU3p787bffitOnT4tnn31WeHl5iaSkJLlLM/D888+LPXv2iMuXL4vY2Fgxfvx44ePjI9W5dOlSoVarxU8//STi4+PF1KlTjd6GFhISInbt2iWOHTsm7rzzTqO3ofXs2VMcPHhQHDx4UPTo0cNmtyfn5+eL48ePi+PHjwsA4qOPPhLHjx+Xbg1vrGOqvLXurrvuEseOHRO7du0SISEhDb51sLbjy8/PF88//7w4cOCASExMFLt37xbR0dEiODjYYY7vmWeeEWq1WuzZs8fg9s6ioiKpjSOfw7qOz9HP4cKFC8W+fftEYmKiOHnypHjllVeEUqkUO3bsEEI49rkz5xgd/fwZU/WuHyEc9xwyqBixfPlyERYWJtzd3UXfvn0Nbj+0F5X3v7u5uYmgoCBx7733ioSEBGm9Xq8Xb775pmjdurVQqVTijjvuEPHx8Qb7KC4uFrNnzxb+/v7C09NTjB8/XqSkpBi0uXHjhnjkkUeEj4+P8PHxEY888ojIycmxyTHt3r1bAKjxqvzfTmMeU3Jyshg3bpzw9PQU/v7+Yvbs2aKkpMRmx1dUVCRGjRolWrVqJdzc3ETbtm3F9OnTa9Ruz8dn7NgAiFWrVkltHPkc1nV8jn4On3jiCenfvVatWom77rpLCilCOPa5M+cYHf38GVM9qDjqOVQIIUT9LxgRERER2R4H0xIREZHdYlAhIiIiu8WgQkRERHaLQYWIiIjsFoMKERER2S0GFSIiIrJbDCpERERktxhUiIiIyG4xqBBRg7Rr1w7Lli0zu/2ePXugUCiQm5trs5qIyHlwZlqiJmbYsGHo3bt3vcJFba5duwYvLy80a9bMrPalpaXIzs5GYGCgxY+0b6g9e/Zg+PDhyMnJQfPmzWWpgYjM4yp3AURkf4QQ0Ol0cHWt+5+IVq1a1Wvf7u7uaN26taWlEVETw0s/RE3IjBkzsHfvXnzyySdQKBRQKBRISkqSLsds374dUVFRUKlU+OOPP3Dp0iVMnDgRgYGB8Pb2Rv/+/bFr1y6DfVa/9KNQKPDNN99g8uTJaNasGTp16oTNmzdL66tf+vnuu+/QvHlzbN++Hd26dYO3tzfuvvtupKenS9uUl5dj7ty5aN68OVq0aIGXXnoJ06dPx6RJk0wea3JyMiZMmAA/Pz94eXmhe/fu+PXXX5GUlIThw4cDAPz8/KBQKDBjxgwAFQHt/fffR4cOHeDp6YlevXrhxx9/rFH7li1b0KtXL3h4eGDgwIGIj4+38IwQUV0YVIiakE8++QTR0dF46qmnkJ6ejvT0dISGhkrrX3zxRcTExODMmTPo2bMnCgoKMHbsWOzatQvHjx/H6NGjMWHCBKSkpNT6PYsXL8aD
Dz6IkydPYuzYsXjkkUeQnZ1tsn1RURE+/PBD/Otf/8K+ffuQkpKCF154QVr/3nvv4d///jdWrVqFP//8ExqNBps2baq1hlmzZkGr1WLfvn2Ij4/He++9B29vb4SGhmLDhg0AgHPnziE9PR2ffPIJAOC1117DqlWrsHLlSiQkJOC5557Do48+ir179xrse8GCBfjwww9x5MgRBAQE4J577kFZWVmt9RCRhSx65jIROazqj34XQojdu3cLAGLTpk11bh8RESE+++wz6XNYWJj4+OOPpc8AxGuvvSZ9LigoEAqFQmzdutXguyofC79q1SoBQFy8eFHaZvny5SIwMFD6HBgYKD744APpc3l5uWjbtq2YOHGiyTp79OghFi1aZHRd9Roq6/Tw8BAHDhwwaPu3v/1NTJ061WC7devWSetv3LghPD09xfr1603WQkSW4xgVIpJERUUZfC4sLMTixYvxyy+/IC0tDeXl5SguLq6zR6Vnz57Sey8vL/j4+CArK8tk+2bNmqFjx47S5zZt2kjt8/LykJmZiQEDBkjrXVxc0K9fP+j1epP7nDt3Lp555hns2LEDI0aMwH333WdQV3WnT59GSUkJRo4cabC8tLQUffr0MVgWHR0tvff390eXLl1w5swZk/smIssxqBCRxMvLy+DzggULsH37dnz44YcIDw+Hp6cn7r//fpSWlta6Hzc3N4PPCoWi1lBhrL2odkNi9TuEqq+v7sknn8To0aOxZcsW7NixAzExMfi///s/zJkzx2j7yvq2bNmC4OBgg3UqlarW7zJWHxFZB8eoEDUx7u7u0Ol0ZrX9448/MGPGDEyePBk9evRA69atkZSUZNsCq1Gr1QgMDMThw4elZTqdDsePH69z29DQUMycORM//fQTnn/+eXz99dcAKn4GlfupFBERAZVKhZSUFISHhxu8qo7jAYDY2FjpfU5ODs6fP4+uXbs26DiJyDj2qBA1Me3atcOhQ4eQlJQEb29v+Pv7m2wbHh6On376CRMmTIBCocDrr79ea8+IrcyZMwcxMTEIDw9H165d8dlnnyEnJ6fWXoxnn30WY8aMQefOnZGTk4Pff/8d3bp1AwCEhYVBoVDgl19+wdixY+Hp6QkfHx+88MILeO6556DX6zFkyBBoNBocOHAA3t7emD59urTvt956Cy1atEBgYCBeffVVtGzZstY7kIjIcuxRIWpiXnjhBbi4uCAiIgKtWrWqdbzJxx9/DD8/PwwaNAgTJkzA6NGj0bdv30astsJLL72EqVOn4rHHHkN0dDS8vb0xevRoeHh4mNxGp9Nh1qxZ6NatG+6++2506dIFK1asAAAEBwdj8eLFePnllxEYGIjZs2cDAN5++2288cYbiImJQbdu3TB69Gj873//Q/v27Q32vXTpUsybNw/9+vVDeno6Nm/eLPXSEJF1cWZaInI4er0e3bp1w4MPPoi333670b6XM9oSNT5e+iEiu5ecnIwdO3Zg6NCh0Gq1+Pzzz5GYmIiHH35Y7tKIyMZ46YeI7J5SqcR3332H/v37Y/DgwYiPj8euXbukMSdE5Lx46YeIiIjsFntUiIiIyG4xqBAREZHdYlAhIiIiu8WgQkRERHaLQYWIiIjsFoMKERER2S0GFSIiIrJbDCpERERkt/4fP4lmDpzGe7IAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from tqdm import trange\n",
    "import matplotlib.pyplot as plt\n",
    "from torch.optim import Adam\n",
    "import numpy as np\n",
    "\n",
    "# Train the sequence-to-sequence machine-translation model\n",
    "def train_seq2seq_mt(train_data, encoder, decoder, epochs=20,\\\n",
    "        learning_rate=1e-3):\n",
    "    \"\"\"Train an encoder/decoder pair with teacher forcing.\n",
    "\n",
    "    train_data: list of (input_ids, target_ids) pairs; shuffled in\n",
    "        place once per epoch\n",
    "    encoder / decoder: nn.Module instances; the decoder is expected\n",
    "        to return per-token log-probabilities (required by NLLLoss)\n",
    "    epochs: number of passes over train_data\n",
    "    learning_rate: Adam learning rate for both optimizers\n",
    "    \"\"\"\n",
    "    # Prepare the optimizers and the loss function\n",
    "    encoder_optimizer = Adam(encoder.parameters(), lr=learning_rate)\n",
    "    decoder_optimizer = Adam(decoder.parameters(), lr=learning_rate)\n",
    "    criterion = nn.NLLLoss()\n",
    "\n",
    "    encoder.train()\n",
    "    decoder.train()\n",
    "\n",
    "    step_losses = []\n",
    "    plot_losses = []\n",
    "    # BUG FIX: iterate over the `epochs` parameter; the original used\n",
    "    # the global `n_epochs`, silently ignoring the argument\n",
    "    with trange(epochs, desc='epoch', ncols=60) as pbar:\n",
    "        for epoch in pbar:\n",
    "            np.random.shuffle(train_data)\n",
    "            for step, data in enumerate(train_data):\n",
    "                # Convert source/target sequences to 1 * seq_len tensors.\n",
    "                # For simplicity the batch size is 1; with larger\n",
    "                # batches the encoder would need padding and would have\n",
    "                # to return the hidden state of the last non-padding\n",
    "                # token, with matching adjustments on the decoder side.\n",
    "                input_ids, target_ids = data\n",
    "                input_tensor, target_tensor = \\\n",
    "                    torch.tensor(input_ids).unsqueeze(0),\\\n",
    "                    torch.tensor(target_ids).unsqueeze(0)\n",
    "\n",
    "                encoder_optimizer.zero_grad()\n",
    "                decoder_optimizer.zero_grad()\n",
    "\n",
    "                encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "                # Feed the target sequence for teacher-forcing training\n",
    "                decoder_outputs, _, _ = decoder(encoder_outputs,\\\n",
    "                    encoder_hidden, target_tensor)\n",
    "\n",
    "                loss = criterion(\n",
    "                    decoder_outputs.view(-1, decoder_outputs.size(-1)),\n",
    "                    target_tensor.view(-1)\n",
    "                )\n",
    "                pbar.set_description(f'epoch-{epoch}, '+\\\n",
    "                    f'loss={loss.item():.4f}')\n",
    "                step_losses.append(loss.item())\n",
    "                # With batch size 1 the per-step loss is very noisy;\n",
    "                # averaging the last 32 steps gives a smoother curve\n",
    "                # that is easier to read\n",
    "                plot_losses.append(np.mean(step_losses[-32:]))\n",
    "                loss.backward()\n",
    "\n",
    "                encoder_optimizer.step()\n",
    "                decoder_optimizer.step()\n",
    "\n",
    "    plot_losses = np.array(plot_losses)\n",
    "    plt.plot(range(len(plot_losses)), plot_losses)\n",
    "    plt.xlabel('training step')\n",
    "    plt.ylabel('loss')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "hidden_size = 128\n",
    "n_epochs = 20\n",
    "learning_rate = 1e-3\n",
    "\n",
    "encoder = RNNEncoder(input_lang.n_words, hidden_size)\n",
    "decoder = AttnRNNDecoder(output_lang.n_words, hidden_size)\n",
    "\n",
    "train_seq2seq_mt(train_data, encoder, decoder, n_epochs, learning_rate)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be47e115",
   "metadata": {},
   "source": [
    "下面实现贪心搜索解码。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "678192b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input： 智 能 新 零 售 新 场 景 新 科 技 新 物 流 新 消 费\n",
      "target： smart new retail , new scene , new technology , new logistics , new consumption .\n",
      "pred： smart new retail , new scene , new technology , new logistics , new scene , new commercial realization , new logistics , new consumption .\n",
      "\n",
      "input： 金 蝶 k i s — — 财 务 软 件 培 训 教 程 ( 第 3 版 )\n",
      "target： golden fluttershy kis — financial software training curriculum (version 3)\n",
      "pred： golden fluttershy kis — financial software training curriculum (version 3)\n",
      "\n",
      "input： 动 态 集 二 次 元 动 漫 人 体 素 材 多 人 篇\n",
      "target： momentum set , binary comics , multiplayers .\n",
      "pred： momentum set , binary comics , multiplayers .\n",
      "\n",
      "input： 酒 店 管 理 实 操 从 入 门 到 精 通\n",
      "target： hotel management , from entry to mastery .\n",
      "pred： from entry to mastery of the value-added version of wealth .\n",
      "\n",
      "input： 微 服 务 中 台 架 构 开 发\n",
      "target： medium-level infrastructure development for micro-services\n",
      "pred： medium-level infrastructure development for micro-services\n",
      "\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Code adapted from the GitHub project pytorch/tutorials\n",
    "(Copyright (c) 2023, PyTorch, BSD-3-Clause License, see appendix)\n",
    "\"\"\"\n",
    "def greedy_decode(encoder, decoder, sentence, input_lang, output_lang):\n",
    "    \"\"\"Greedily decode a translation of `sentence`.\n",
    "\n",
    "    At every step only the single most probable token is kept.\n",
    "    Returns (decoded target sentence, decoder attention weights).\n",
    "    Relies on the global EOS_token defined elsewhere in the notebook.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # Convert the source sentence to a 1 * seq_length tensor\n",
    "        input_ids = input_lang.sent2ids(sentence)\n",
    "        input_tensor = torch.tensor(input_ids).unsqueeze(0)\n",
    "        \n",
    "        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "        # Run the decoder without a target (free-running mode);\n",
    "        # presumably it unrolls up to a fixed maximum length and\n",
    "        # returns per-step log-probabilities -- confirm against the\n",
    "        # decoder implementation\n",
    "        decoder_outputs, decoder_hidden, decoder_attn = \\\n",
    "            decoder(encoder_outputs, encoder_hidden)\n",
    "        \n",
    "        # Take the most probable token at each step\n",
    "        _, topi = decoder_outputs.topk(1)\n",
    "        \n",
    "        decoded_ids = []\n",
    "        for idx in topi.squeeze():\n",
    "            # Stop at the first end-of-sequence token\n",
    "            if idx.item() == EOS_token:\n",
    "                break\n",
    "            decoded_ids.append(idx.item())\n",
    "    return output_lang.ids2sent(decoded_ids), decoder_attn\n",
    "            \n",
    "# Show five random examples: source, reference, greedy prediction\n",
    "encoder.eval()\n",
    "decoder.eval()\n",
    "for i in range(5):\n",
    "    pair = random.choice(pairs)\n",
    "    print('input：', pair[0])\n",
    "    print('target：', pair[1])\n",
    "    output_sentence, _ = greedy_decode(encoder, decoder, pair[0],\n",
    "        input_lang, output_lang)\n",
    "    print('pred：', output_sentence)\n",
    "    print('')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e38320f0",
   "metadata": {},
   "source": [
    "\n",
    "接下来使用束搜索解码来验证模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "3496efa3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input： p y t h o n 程 序 设 计 基 础 教 程 （ 慕 课 版 ）\n",
      "target： python program design basic curriculum (curriculus version)\n",
      "pred： python program design (video curriculum (curriculus version)\n",
      "\n",
      "input： 极 简 抗 压 行 动 法 高 效 能 人 士 如 何 管 理 压 力\n",
      "target： it's a very simple anti-pressure method .\n",
      "pred： it's a very simple anti-pressure method .\n",
      "\n",
      "input： 用 友 e r p - u 8 业 务 、 财 务 模 拟 实 战 （ v 1 3 版 ）\n",
      "target： business , financial simulation field operations with friend erp-u8 (v13)\n",
      "pred： business , financial simulation field operations with friend erp-u8 (v13)\n",
      "\n",
      "input： s p a r k 大 数 据 技 术 与 应 用\n",
      "target： spark big data technology and applications\n",
      "pred： technology and big data technology and applications based on techniques and technology foundation\n",
      "\n",
      "input： u n i t y 案 例 开 发 大 全 第 2 版\n",
      "target： unity case development full , 2nd ed .\n",
      "pred： unity case development full , 2nd ed .\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Container class that manages all candidate hypotheses of a beam search\n",
    "class BeamHypotheses:\n",
    "    def __init__(self, num_beams, max_length):\n",
    "        # maximum allowed decoding length for any hypothesis\n",
    "        self.max_length = max_length\n",
    "        # beam width: maximum number of hypotheses kept at once\n",
    "        self.num_beams = num_beams\n",
    "        # list of (score, token-id list, decoder hidden state) tuples\n",
    "        self.beams = []\n",
    "        # score of the worst kept hypothesis (sentinel until updated)\n",
    "        self.worst_score = 1e9\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.beams)\n",
    "    \n",
    "    # Add one hypothesis and update the worst score\n",
    "    def add(self, sum_logprobs, hyp, hidden):\n",
    "        # Length-normalized score (max guards the empty hypothesis)\n",
    "        score = sum_logprobs / max(len(hyp), 1)\n",
    "        if len(self) < self.num_beams or score > self.worst_score:\n",
    "            # Accept when below capacity or better than the current worst\n",
    "            self.beams.append((score, hyp, hidden))\n",
    "            if len(self) > self.num_beams:\n",
    "                # Over capacity: delete the single worst hypothesis\n",
    "                sorted_scores = sorted([(s, idx) for idx,\\\n",
    "                    (s, _, _) in enumerate(self.beams)])\n",
    "                del self.beams[sorted_scores[0][1]]\n",
    "                self.worst_score = sorted_scores[1][0]\n",
    "            else:\n",
    "                self.worst_score = min(score, self.worst_score)\n",
    "    \n",
    "    # Pop one unfinished hypothesis; the first return value indicates\n",
    "    # success, and on success the second value is the popped hypothesis\n",
    "    def pop(self):\n",
    "        if len(self) == 0:\n",
    "            return False, None\n",
    "        for i, (s, hyp, hid) in enumerate(self.beams):\n",
    "            # Unfinished means: shorter than the maximum decoding\n",
    "            # length, and not ending with <eos>\n",
    "            if len(hyp) < self.max_length and (len(hyp) == 0\\\n",
    "                    or hyp[-1] != EOS_token):\n",
    "                del self.beams[i]\n",
    "                if len(self) > 0:\n",
    "                    # Recompute the worst score over the remaining beams\n",
    "                    sorted_scores = sorted([(s, idx) for idx,\\\n",
    "                        (s, _, _) in enumerate(self.beams)])\n",
    "                    self.worst_score = sorted_scores[0][0]\n",
    "                else:\n",
    "                    self.worst_score = 1e9\n",
    "                return True, (s, hyp, hid)\n",
    "        return False, None\n",
    "    \n",
    "    # Pop the highest-scoring hypothesis; the first return value\n",
    "    # indicates success, and on success the second value is that hypothesis\n",
    "    def pop_best(self):\n",
    "        if len(self) == 0:\n",
    "            return False, None\n",
    "        sorted_scores = sorted([(s, idx) for idx, (s, _, _)\\\n",
    "            in enumerate(self.beams)])\n",
    "        return True, self.beams[sorted_scores[-1][1]]\n",
    "\n",
    "\n",
    "def beam_search_decode(encoder, decoder, sentence, input_lang,\n",
    "        output_lang, num_beams=3):\n",
    "    \"\"\"Translate `sentence` with beam search of width `num_beams`.\n",
    "\n",
    "    Returns the decoded target sentence, or '' if no hypothesis was\n",
    "    produced. Relies on the globals SOS_token, EOS_token and\n",
    "    MAX_LENGTH defined elsewhere in the notebook.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # Convert the source sentence to a 1 * seq_length tensor\n",
    "        input_ids = input_lang.sent2ids(sentence)\n",
    "        input_tensor = torch.tensor(input_ids).unsqueeze(0)\n",
    "\n",
    "        # Insert an initial empty hypothesis into the container\n",
    "        encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
    "        init_hyp = []\n",
    "        hypotheses = BeamHypotheses(num_beams, MAX_LENGTH)\n",
    "        hypotheses.add(0, init_hyp, encoder_hidden)\n",
    "\n",
    "        while True:\n",
    "            # Pop one unfinished hypothesis per iteration; the loop\n",
    "            # ends once every kept hypothesis is finished\n",
    "            flag, item = hypotheses.pop()\n",
    "            if not flag:\n",
    "                break\n",
    "                \n",
    "            score, hyp, decoder_hidden = item\n",
    "            \n",
    "            # Current decoder input: the last generated token,\n",
    "            # or <sos> for an empty hypothesis\n",
    "            if len(hyp) > 0:\n",
    "                decoder_input = torch.empty(1, 1,\\\n",
    "                    dtype=torch.long).fill_(hyp[-1])\n",
    "            else:\n",
    "                decoder_input = torch.empty(1, 1,\\\n",
    "                    dtype=torch.long).fill_(SOS_token)\n",
    "\n",
    "            # Decode a single step\n",
    "            decoder_output, decoder_hidden, _ = decoder.forward_step(\n",
    "                decoder_input, decoder_hidden, encoder_outputs\n",
    "            )\n",
    "\n",
    "            # Take the top-k tokens from the output distribution\n",
    "            topk_values, topk_ids = decoder_output.topk(num_beams)\n",
    "            # Extend the hypothesis with each candidate token and add\n",
    "            # the new hypotheses back into the container\n",
    "            for logp, token_id in zip(topk_values.squeeze(),\\\n",
    "                    topk_ids.squeeze()):\n",
    "                # score is length-normalized, so multiply by len(hyp)\n",
    "                # to recover the running sum of log-probabilities\n",
    "                sum_logprobs = score * len(hyp) + logp.item()\n",
    "                new_hyp = hyp + [token_id.item()]\n",
    "                hypotheses.add(sum_logprobs, new_hyp, decoder_hidden)\n",
    "\n",
    "        # Return the best hypothesis, stripping a trailing <eos>\n",
    "        flag, item = hypotheses.pop_best()\n",
    "        if flag:\n",
    "            hyp = item[1]\n",
    "            if hyp[-1] == EOS_token:\n",
    "                del hyp[-1]\n",
    "            return output_lang.ids2sent(hyp)\n",
    "        else:\n",
    "            return ''\n",
    "\n",
    "# Evaluate: translate five random pairs with beam search and\n",
    "# print source, reference and prediction for each\n",
    "encoder.eval()\n",
    "decoder.eval()\n",
    "for _ in range(5):\n",
    "    example = random.choice(pairs)\n",
    "    source_sent, target_sent = example[0], example[1]\n",
    "    print('input：', source_sent)\n",
    "    print('target：', target_sent)\n",
    "    translation = beam_search_decode(encoder, decoder,\\\n",
    "        source_sent, input_lang, output_lang)\n",
    "    print('pred：', translation)\n",
    "    print('')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e6b0958-eee2-479d-90d6-dc4fe57d9027",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
